{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "unavailable-korea",
   "metadata": {},
   "source": [
    "# Molecular generation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "freelance-russell",
   "metadata": {},
   "source": [
    "In this tutorial, we will go through how to train a sequence VAE model for generating molecules with the format of SMILES sequence. In particular, we will demostrate how to train a VAE model and sample the generative molecules from a pre-trained model.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bound-amendment",
   "metadata": {},
   "source": [
    "## Sequence VAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abstract-showcase",
   "metadata": {},
   "source": [
    "Molecular generation is a popular tool to produce new molecules by training deep generative models on large dataset. The generative models could be used for designing new molecules, exploring molecular space etc. The generative molecules could be further used for virtual screening or other downstream tasks. In this work, we will introduce a Variational Autoencoders (VAE) based generative model.\n",
    "\n",
    "VAE contains two neural nets - an encoder and a decoder. With this structure, the model could convert the high dimensional input space into a low dimensonal latent space by an encoder and convert back to original input space in order for construction by a decoder. The latent space is a continous vector space with normal distribution. We minimize both Kullback-Leibler(KL) divergence loss and reconstruction loss. With the nice property of continous latent space, we could sample the new molecules using the trained-VAE model.\n",
    "\n",
    "The input of molecules are the SMILES sequence. By combining both, the sequence VAE model will take a SMILES sequence as input and reconstruct the input sequence."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "revolutionary-luxury",
   "metadata": {},
   "source": [
    "![title](./figures/seq_VAE.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "considerable-cherry",
   "metadata": {},
   "source": [
    "## Part I: Train a seq-VAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "nonprofit-fleece",
   "metadata": {},
   "source": [
    "### Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "immediate-glenn",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "seq_VAE_path = '../apps/molecular_generation/seq_VAE/'\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "sys.path.append(seq_VAE_path)\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "informative-familiar",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-05-13 14:38:50--  https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/molecular_generation/zinc_moses.tgz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 10.70.0.165\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|10.70.0.165|:443... connected.\n",
      "HTTP request sent, awaiting response... "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 OK\n",
      "Length: 8708409 (8.3M) [application/gzip]\n",
      "Saving to: ‘zinc_moses.tgz.1’\n",
      "\n",
      "zinc_moses.tgz.1    100%[===================>]   8.30M  2.54MB/s    in 3.5s    \n",
      "\n",
      "2021-05-13 14:38:54 (2.34 MB/s) - ‘zinc_moses.tgz.1’ saved [8708409/8708409]\n",
      "\n",
      "x zinc_moses/\n",
      "x zinc_moses/.DS_Store\n",
      "x zinc_moses/test.csv\n",
      "x zinc_moses/train.csv\n"
     ]
    }
   ],
   "source": [
    "# download and decompress the data\n",
    "!wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/molecular_generation/zinc_moses.tgz\n",
    "!tar -zxvf \"zinc_moses.tgz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "chronic-billion",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = './zinc_moses/train.csv'\n",
    "train_data = load_zinc_dataset(data_path)\n",
    "# get the toy data\n",
    "train_data = train_data[0:1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "noticed-minutes",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "powered-traveler",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1',\n",
       " 'CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1',\n",
       " 'Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO',\n",
       " 'Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C',\n",
       " 'CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O',\n",
       " 'CCOC(=O)c1cncn1C1CCCc2ccccc21',\n",
       " 'COc1ccccc1OC(=O)Oc1ccccc1OC',\n",
       " 'O=C1Nc2ccc(Cl)cc2C(c2ccccc2Cl)=NC1O',\n",
       " 'CN1C(=O)C(O)N=C(c2ccccc2Cl)c2cc(Cl)ccc21',\n",
       " 'CCC(=O)c1ccc(OCC(O)CO)c(OC)c1']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "south-madness",
   "metadata": {},
   "source": [
    "### define vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "sublime-carnival",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the sequence vocabuary based on dataset\n",
    "vocab = OneHotVocab.from_data(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "christian-power",
   "metadata": {},
   "source": [
    "### Model Configuration Settings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "rising-intellectual",
   "metadata": {},
   "source": [
    "The network is built up on hyperparameters from model_config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "heated-observer",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = \\\n",
    "{\n",
    "    \"max_length\":80,     # max length of sequence\n",
    "    \"q_cell\": \"gru\",     # encoder RNN cell\n",
    "    \"q_bidir\": 1,        # if encoder is bidiretion\n",
    "    \"q_d_h\": 256,        # hidden size of encoder\n",
    "    \"q_n_layers\": 1,     # number of layers of encoder RNN\n",
    "    \"q_dropout\": 0.5,    # encoder drop out rate\n",
    "\n",
    "\n",
    "    \"d_cell\": \"gru\",     # decoder RNN cell\n",
    "    \"d_n_layers\":3,      # number of decoder layers\n",
    "    \"d_dropout\":0.2,     # decoder drop out rate\n",
    "    \"d_z\":128,           # latent space size\n",
    "    \"d_d_h\":512,         # hidden size of decoder\n",
    "    \"freeze_embeddings\":0 # if freeze embeddings\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "advisory-syndicate",
   "metadata": {},
   "source": [
    "### Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "binary-eugene",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the model\n",
    "from pahelix.model_zoo.seq_vae_model  import VAE\n",
    "model = VAE(vocab, model_config)  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "consecutive-mouth",
   "metadata": {},
   "source": [
    "### Trian the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "physical-policy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the training settings\n",
    "batch_size = 64\n",
    "learning_rate = 0.001\n",
    "n_epoch = 1\n",
    "kl_weight = 0.1\n",
    "\n",
    "# define optimizer\n",
    "optimizer = paddle.optimizer.Adam(parameters=model.parameters(),\n",
    "                            learning_rate=learning_rate)\n",
    "\n",
    "# build the dataset and data loader\n",
    "max_length = model_config[\"max_length\"]\n",
    "train_dataset = StringDataset(vocab, train_data, max_length)\n",
    "train_dataloader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "flexible-celtic",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#######################\n",
      "batch:0, kl_loss:0.377259, recon_loss:3.379486\n",
      "batch:1, kl_loss:0.259201, recon_loss:3.264177\n",
      "batch:2, kl_loss:0.210570, recon_loss:3.144137\n",
      "batch:3, kl_loss:0.205814, recon_loss:3.053869\n",
      "batch:4, kl_loss:0.204681, recon_loss:2.960207\n",
      "batch:5, kl_loss:0.205177, recon_loss:2.892930\n",
      "batch:6, kl_loss:0.203757, recon_loss:2.838837\n",
      "batch:7, kl_loss:0.201053, recon_loss:2.782497\n",
      "batch:8, kl_loss:0.197671, recon_loss:2.751050\n",
      "batch:9, kl_loss:0.192766, recon_loss:2.715708\n",
      "batch:10, kl_loss:0.186594, recon_loss:2.684680\n",
      "batch:11, kl_loss:0.179440, recon_loss:2.664472\n",
      "batch:12, kl_loss:0.171974, recon_loss:2.641148\n",
      "batch:13, kl_loss:0.164508, recon_loss:2.620756\n",
      "batch:14, kl_loss:0.157552, recon_loss:2.605232\n",
      "batch:15, kl_loss:0.151044, recon_loss:2.586791\n",
      "epoch:0 loss:2.601895 kl_loss:0.151044 recon_loss:2.586791\n"
     ]
    }
   ],
   "source": [
    "# start to train \n",
    "for epoch in range(n_epoch):\n",
    "    print('#######################')\n",
    "    kl_loss_values = []\n",
    "    recon_loss_values = []\n",
    "    loss_values = []\n",
    "    \n",
    "    for batch_id, data in enumerate(train_dataloader()):\n",
    "        # read batch data\n",
    "        data_batch = data\n",
    "\n",
    "        # forward\n",
    "        kl_loss, recon_loss  = model(data_batch)\n",
    "        loss = kl_weight * kl_loss + recon_loss\n",
    "\n",
    "\n",
    "        # backward\n",
    "        loss.backward()\n",
    "        # optimize\n",
    "        optimizer.step()\n",
    "        # clear gradients\n",
    "        optimizer.clear_grad()\n",
    "        \n",
    "        # gathering values from each batch\n",
    "        kl_loss_values.append(kl_loss.numpy())\n",
    "        recon_loss_values.append(recon_loss.numpy())\n",
    "        loss_values.append(loss.numpy())\n",
    "\n",
    "        \n",
    "        print('batch:%s, kl_loss:%f, recon_loss:%f' % (batch_id, float(np.mean(kl_loss_values)), float(np.mean(recon_loss_values))))\n",
    "        \n",
    "    print('epoch:%d loss:%f kl_loss:%f recon_loss:%f' % (epoch, float(np.mean(loss_values)), float(np.mean(kl_loss_values)),float(np.mean(recon_loss_values))),flush=True)\n",
    "\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "alpine-theta",
   "metadata": {},
   "source": [
    "## Part II: Sample from prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "serial-campaign",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'valid': 0.013000000000000012, 'unique@3': 0.6666666666666666, 'IntDiv': 0.7307692307692307, 'IntDiv2': 0.5181166128686162, 'Filters': 0.9230769230769231}\n"
     ]
    }
   ],
   "source": [
    "from pahelix.utils.metrics.molecular_generation.metrics_ import get_all_metrics\n",
    "N_samples = 1000  # number of samples \n",
    "max_len = 80      # maximum length of samples\n",
    "current_samples = model.sample(N_samples, max_len)  # get the samples from pre-trained model\n",
    "\n",
    "metrics = get_all_metrics(gen=current_samples, k=[3])  # get the evaluation from samples\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "boolean-cowboy",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
